Coordinate Attention Paper Reading

Coordinate Attention for Efficient Mobile Network Design

CVPR2021

Motivation:

Baseline: SE attention, CBAM
CBAM:
image|690x218,50% image|690x369,50%
就是CA以及SA的结合attention。

Pipeline对比:
image|690x250
前面的两者都没有很好的考虑空间结构的相关性。

  • SE 使用GAP将每一个通道化为一个标量,能保留全局信息,但是lack of 结构信息。
  • CBAM 使用两个阶段来进行attention,一个是Channle-wise的也是同SE一样没有结构信息。另一个是Spatial Attention Module,这个模块将所有的channle 压缩到一个维度中,损失了部分信息。
  • Coordinate attetnion: 在X,Y方向上做Pooling操作,这样一方面没有损失掉空间的信息,X 过Pool后可以用Y Pool的结果来弥补。也没有丢失整个Channle的信息。

Method:

image|663x267,75%
类比于SE中计算的整个单个Channle的均值或最大值,这里对feature $\mathbb{R}^{C \times H \times W}$ 进行对X轴、Y轴的均值计算,得到 $Z^w \in \mathbb{R}^{C \times 1 \times W} $ 和 $Z^h \in \mathbb{R}^{C \times H \times 1} $ 将这两个特征Cat起来经过Conv,进行BN和非线性激活:
image|262x63
再split后,分别过两个Conv,其中 $F_h , F_w$ 是将 $f^h$ 通道数与input的feature一样。(之前Channle 有ratio)
image|228x96
最后得到结合X,Y的方向信息的特征:
image|658x195

Results:

image|632x500,50% image|580x500,50%
image|690x255

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class h_sigmoid(nn.Module):
def __init__(self, inplace=True):
super(h_sigmoid, self).__init__()
self.relu = nn.ReLU6(inplace=inplace)

def forward(self, x):
return self.relu(x + 3) / 6

class h_swish(nn.Module):
def __init__(self, inplace=True):
super(h_swish, self).__init__()
self.sigmoid = h_sigmoid(inplace=inplace)

def forward(self, x):
return x * self.sigmoid(x)

class CoordAtt(nn.Module):
def __init__(self, inp, oup, reduction=32):
super(CoordAtt, self).__init__()
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))

mip = max(8, inp // reduction)

self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
self.bn1 = nn.BatchNorm2d(mip)
self.act = h_swish()

self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)


def forward(self, x):
identity = x

n,c,h,w = x.size()
x_h = self.pool_h(x)
x_w = self.pool_w(x).permute(0, 1, 3, 2)

y = torch.cat([x_h, x_w], dim=2)
y = self.conv1(y)
y = self.bn1(y)
y = self.act(y)

x_h, x_w = torch.split(y, [h, w], dim=2)
x_w = x_w.permute(0, 1, 3, 2)

a_h = self.conv_h(x_h).sigmoid()
a_w = self.conv_w(x_w).sigmoid()

out = identity * a_w * a_h

return out